from src.bulletEnv import BulletEnv
import os
import time
import utils
from stable_baselines3 import SAC
import utils
import pickle
import numpy as np

class MultiSim(object):
    """Run a trained policy on a batch of vectorized simulator environments.

    The instance owns the vectorized environment object returned by
    ``utils.env_utils.make_subproc_envs`` (``self.envs``), the most recent
    batch of observations (``self.obs``), the actions issued since the last
    reset/stash (``self.actions``) and a "run bag" dictionary
    (``self.runbag``) that is pickled by :meth:`save_actions`.
    """

    # Fallback file-name prefix used by ``save_actions`` when ``setup_bag`` was
    # never given an explicit ``prefix``. Without it ``save_actions`` raised
    # AttributeError, because ``__init__`` calls ``setup_bag`` with no prefix.
    # Assigning ``instance.bag_prefix`` still overrides this default.
    bag_prefix = "multisim"

    def __init__(self, init_config, goal_config, problem_config, term_sampler,mp = False,seed=1337,nsims=2):
        """Create ``nsims`` environments and reset them once in 'eval' mode.

        Args:
            init_config: Passed to the environments as ``init_set`` and stored
                in the run bag under ``"init_config"``.
            goal_config: Passed to the environments as ``term_set`` and stored
                in the run bag under ``"goal_config"``.
            problem_config: Mapping read for the keys ``"simulator_gui"``,
                ``"debug"``, ``"robot"``, ``"env_path"``, ``"env_name"`` and
                ``"region_policy"["max_ep_len"]``.
            term_sampler: Forwarded unchanged to the environment factory.
            mp: When true the GUI is forced off and the ``"mp_robot_"``
                environment prefix is used instead of ``"sim_robot_"``.
            seed: Seed forwarded to the environment factory.
            nsims: Number of environments to create.
        """
        print("Initializing simulators")
        self.actions = []
        self.problem_config = problem_config
        self.runbag = self.setup_bag(init_config, goal_config)
        if mp:
            robot_key = "mp_robot_"
            gui = False
        else:
            robot_key = "sim_robot_"
            gui = problem_config["simulator_gui"]

        env_file = os.path.join(self.problem_config['env_path'],
                                self.problem_config['env_name']+".stl")
        self.envs = utils.env_utils.make_subproc_envs(num=nsims,
                                                     gui=gui,
                                                     seed=seed,
                                                     init_set=init_config,
                                                     term_set=goal_config,
                                                     term_sampler=term_sampler,
                                                     option_guide=None,
                                                     region_switch_point=None,
                                                     demo_mode=self.problem_config['debug'],
                                                     robot_config=self.problem_config['robot'],
                                                     env_path=env_file,
                                                     forked=True,
                                                     env_prefix=robot_key,
                                                     max_ep_len=self.problem_config["region_policy"]['max_ep_len'])
        self.envs.set_attr('env_mode','eval')
        self.obs = self.envs.reset()

    def set_init_sample_and_eval_func(self,init_sampler,eval_func,term_sampler):
        """Clear the action log and push new samplers/eval function to every env."""
        print("Resetting sampler and eval func")
        self.reset_action_log()
        self.envs.env_method("set_sampler_and_eval_func",init_sampler, eval_func, term_sampler)

    def sync_simulator(self, init_config):
        """Call ``setRobotState(sim=False, pose=init_config)`` on every env.

        Returns:
            True only if every environment returned a truthy value.
        """
        print("Syncing simulator")
        return all(self.envs.env_method("setRobotState",
                             sim=False,
                             pose=init_config))

    def reset(self):
        """Reset all environments and store the new observations in ``self.obs``.

        The environments are switched to 'train' mode for the duration of the
        reset and back to 'eval' afterwards.
        NOTE(review): presumably the env's reset only re-samples a start state
        in 'train' mode -- confirm against the environment implementation.
        """
        self.envs.set_attr('env_mode','train')
        self.obs = self.envs.reset()
        self.envs.set_attr('env_mode','eval')

    def get_collision_fn(self):
        """Return a ``collision_fn(pose)`` callable backed by the environments.

        The callable invokes ``collision_at_pose`` through ``env_method`` and
        returns the first environment's answer.
        """
        def collision_fn(pose):
            return self.envs.env_method("collision_at_pose",pose)[0]
        return collision_fn

    def plot(self,pose):
        """Forward ``pose`` to the ``plot`` method of every environment."""
        self.envs.env_method("plot",pose)

    def execute_policy(self, policy_path, option_guide, option_switch_point, timeout=60):
        """Load a SAC policy and step all environments with it.

        Stepping stops when any of these happens first: ``max_ep_len`` steps
        were taken, ``timeout`` seconds elapsed, or every environment has
        reported ``done`` at least once. Each action batch is appended to
        ``self.actions``.

        Args:
            policy_path: Path handed to ``SAC.load``.
            option_guide: Forwarded to the envs' ``update_targets`` method.
            option_switch_point: Forwarded to the envs' ``update_targets`` method.
            timeout: Time budget in seconds for the stepping loop
                (default 60, the value that used to be hard-coded).
        """
        # Run agent in eval mode
        print("Executing policy...")
        model = SAC.load(policy_path)
        self.envs.set_attr("info",utils.env_utils.create_env_info_dict())
        self.envs.env_method("update_targets",option_guide, option_switch_point)

        # env index -> terminal observation of the first episode that env finished
        last_observations = {}
        action_count = 0
        max_ep_len = self.envs.get_attr("max_ep_len")[0]
        # A monotonic clock keeps the budget immune to system clock adjustments.
        deadline = time.monotonic() + timeout
        while action_count < max_ep_len and time.monotonic() < deadline:
            actions, _states = model.predict(self.obs)

            # Envs that already finished get an all-zero action so they do not
            # keep acting while the remaining envs are still running.
            for idx in last_observations:
                actions[idx] = [0.] * len(actions[idx])

            self.actions.append(actions)
            self.obs, rewards, dones, infos = self.envs.step(actions)

            for idx, done in enumerate(dones):
                if done and last_observations.get(idx) is None:
                    print(action_count, idx)
                    # The vectorized env exposes the last real observation of a
                    # finished episode through info['terminal_observation'].
                    last_observations[idx] = infos[idx]['terminal_observation']

            action_count += 1
            if len(last_observations) == len(dones):
                break

        # Set done to False because eval mode does not touch the done state
        self.envs.set_attr("done",False)
        self.envs.set_attr("episode_step_ctr", 0)
        self.envs.env_method("reset_target_idx")
        self.envs.env_method("reset_local_target")

        # NOTE: an earlier revision pushed each env's terminal observation
        # (x, y, theta) back via setRobotState and overwrote self.obs with it;
        # that logic is disabled and the envs are only queried for their state.
        curpos = self.envs.env_method("getRobotCurState")
        print(f"{action_count=} {curpos=}")
        print("Finished policy exec")

    def get_ll_config(self):
        """Return the ``getRobotCurState`` result of every env (one entry per env)."""
        print("Fetching LL config")
        return self.envs.env_method("getRobotCurState")

    def save_actions(self,logtype,logdir):
        """Pickle ``self.runbag`` into ``logdir``.

        The file is named ``<bag_prefix>_<logtype>_<timestamp>.pkl``. The
        directory is created when it does not exist yet.

        Args:
            logtype: Tag inserted into the file name.
            logdir: Destination directory.
        """
        os.makedirs(logdir, exist_ok=True)
        # '%b' is the portable spelling of the platform-specific '%h' alias
        # (abbreviated month name); the produced timestamp is unchanged.
        stamp = time.strftime('%d%y%b_%H%M%S')
        fname = os.path.join(logdir,
                             self.bag_prefix+"_"+logtype+"_"+stamp+'.pkl')

        with open(fname, 'wb') as pickleout:
            pickle.dump(self.runbag, pickleout)

    def reset_action_log(self):
        """Discard all actions recorded since the last reset/stash."""
        self.actions = []

    def stash_actions(self):
        """Move the recorded actions into ``runbag['run_actions']`` and clear the log."""
        self.runbag['run_actions'].extend(self.actions)
        self.reset_action_log()

    def setup_bag(self,init_config, goal_config,prefix=None):
        """Build a fresh run bag and reset ``cur_action_index`` to 0.

        Args:
            init_config: Stored under ``"init_config"``.
            goal_config: Stored under ``"goal_config"``.
            prefix: Optional file-name prefix for :meth:`save_actions`; when
                omitted the current ``bag_prefix`` is left untouched.

        Returns:
            The new run bag dictionary. It is NOT assigned to ``self.runbag``
            here; the caller decides what to do with it.
        """
        if prefix is not None:
            self.bag_prefix = prefix
        self.cur_action_index = 0
        return {"init_config":init_config,
                "goal_config":goal_config,
                "robot_name":self.problem_config['robot']['name'],
                "robot_file":self.problem_config['robot']['model_path'],
                "env_name":self.problem_config['env_name'],
                "run_actions":[]}

    def config_reached(self, goal_config, tolerance):
        """Check whether any simulator env is within ``tolerance`` of the goal.

        Only the first two components of each configuration are compared
        (Euclidean distance, strict ``<``). A pass/fail line is printed for
        every env.

        Returns:
            True if at least one env passes, otherwise False.
        """
        reached = False
        for env_number, config in enumerate(self.get_ll_config(), start=1):
            distance = np.hypot(config[0]-goal_config[0], config[1]-goal_config[1])
            if distance < tolerance:
                print("Simulator passed: #{}".format(env_number))
                reached = True
            else:
                print("Simulator failed to reach goal: #{}".format(env_number))

        return reached
